import copy
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class Network(nn.Module):
    """Multi-view autoencoder with a shared linear clustering head.

    For every view in ``all_view`` the network owns an encoder, a decoder,
    a weight net (``w_nets``) and an imitator. ``forward`` encodes every view,
    averages the embeddings of the views listed in ``give_view`` and feeds the
    fused embedding to a bias-free linear clustering head. ``w_nets`` and
    ``imitators`` are built here but are not used by ``forward``.
    """

    def __init__(self, id, all_view, have_view, args, device):
        """Build all per-view sub-networks and the clustering head.

        Args:
            id: Identifier stored as ``self.id`` (not used inside this class).
            all_view: View ids; one encoder/decoder/w_net/imitator is built per
                entry. ``forward`` indexes per-view lists of length
                ``len(all_view)`` by these ids, so they are presumably
                ``0 .. len(all_view) - 1`` — confirm with callers.
            have_view: Stored as ``self.have_view`` (not used inside this class).
            args: Config object providing ``input_dims`` (indexable by view id),
                ``output_dim`` (embedding size) and ``class_num``.
            device: Torch device that all sub-modules and ``centers`` live on.
        """
        super().__init__()
        self.id = id
        self.all_view = all_view
        self.have_view = have_view
        self.args = args
        self.device = device

        input_dims = args.input_dims
        output_dim = args.output_dim
        n_views = len(all_view)

        # Learnable cluster centers, one row of size output_dim per class.
        self.centers = nn.Parameter(torch.randn(args.class_num, output_dim, device=device))

        # Sub-modules are created in the same order as before (encoder,
        # decoder, w_net, imitator per view) so that seeded parameter
        # initialisation stays reproducible.
        encoders, decoders, imitators, w_nets = [], [], [], []
        for v in all_view:
            encoders.append(nn.Sequential(
                nn.Linear(input_dims[v], 512), nn.ReLU(),
                nn.Linear(512, 128), nn.ReLU(),
                nn.Linear(128, output_dim),
            ))
            decoders.append(nn.Sequential(
                nn.Linear(output_dim, 128), nn.ReLU(),
                nn.Linear(128, 512), nn.ReLU(),
                nn.Linear(512, input_dims[v]),
            ))
            # Maps an (n_views + output_dim) vector to a flattened
            # output_dim x output_dim matrix; unused by forward().
            w_nets.append(nn.Sequential(
                nn.Linear(n_views + output_dim, 128), nn.ReLU(),
                nn.Linear(128, output_dim * output_dim),
            ))
            imitators.append(nn.Sequential(
                nn.Linear(output_dim, 128), nn.ReLU(),
                nn.Linear(128, output_dim),
            ))

        # Registration order (encoders, decoders, imitators, w_nets, cluster)
        # matches the original, keeping state_dict layout unchanged.
        self.encoders = nn.ModuleList(encoders).to(self.device)
        self.decoders = nn.ModuleList(decoders).to(self.device)
        self.imitators = nn.ModuleList(imitators).to(self.device)
        self.w_nets = nn.ModuleList(w_nets).to(self.device)

        # Clustering head: one logit per class, no bias.
        self.cluster = nn.Sequential(
            nn.Linear(output_dim, args.class_num, bias=False)
        ).to(self.device)

    def forward(self, xs, give_view):
        """Encode all views, fuse the given views and compute cluster outputs.

        Args:
            xs: Per-view inputs indexed by view id; ``xs[v]`` is fed to
                ``self.encoders[v]``, so its last dimension must equal
                ``args.input_dims[v]`` (presumably shape (batch, input_dims[v])).
            give_view: Non-empty collection of view ids whose embeddings are
                averaged into the fused representation ``h``.

        Returns:
            Tuple ``(xrs, zs, hs, h, ws, pt, labels, centers)``:
            xrs: per-view decoder outputs computed from ``zs[v]``.
            zs: per-view encoder outputs, L2-normalised along dim 1; every view
                in ``all_view`` is encoded, not only ``give_view``.
            hs, ws: placeholder lists of zeros, one entry per view.
            h: mean of ``zs[v]`` over ``give_view`` (not re-normalised).
            pt: softmax over classes (dim 1) of ``self.cluster(h)``.
            labels: Python list with the argmax class of each row of ``pt``.
            centers: ``self.centers`` L2-normalised row-wise.

        Raises:
            ValueError: If ``give_view`` is empty (there is nothing to fuse).
        """
        av_num = len(self.all_view)
        gv_num = len(give_view)
        if gv_num == 0:
            # Without this guard h would stay the int 0 and nn.Linear would
            # fail with an unhelpful TypeError.
            raise ValueError("give_view must contain at least one view")

        # Encode every view; per-view lists are indexed by view id.
        zs = [0] * av_num
        for v in self.all_view:
            zs[v] = F.normalize(self.encoders[v](xs[v])).to(self.device)

        # Placeholders kept for the return signature.
        hs = [0] * av_num
        ws = [0] * av_num

        # Fuse the provided views by averaging (same summation order as before).
        h = sum(zs[v] / gv_num for v in give_view)

        # Reconstruct every view from its own embedding.
        xrs = [0] * av_num
        for v in self.all_view:
            xrs[v] = self.decoders[v](zs[v]).to(self.device)

        # Explicit dim=1 (the class axis, matching argmax below); the
        # implicit-dim form is deprecated and emits a warning.
        pt = F.softmax(self.cluster(h), dim=1)
        labels = pt.argmax(dim=1).tolist()

        centers = F.normalize(self.centers)

        return xrs, zs, hs, h, ws, pt, labels, centers
